from torch import nn
from .autoencoders import VAELatentLayer, SplitLatentLayer, MergeLatentLayer, GMVAELatentLayer
from .adversarial import AdversarialLatentLayer
from .vqvae import VQVAELatentLayer
import logging
from abc import ABC
from abc import abstractmethod

# class AbstractLatentLayer(ABC):
#     def __init__(self):
#         self.is_adversarial = True
#
#     @abstractmethod
#     def forward(self, h, g):
#         pass

def create_latent_layer(**config) -> nn.Module:
    """Instantiate a latent layer selected by ``config['type']``.

    The whole ``config`` mapping, including the ``'type'`` entry itself, is
    forwarded to the chosen class's constructor as keyword arguments.

    Raises:
        KeyError: if ``config`` has no ``'type'`` entry.
        ValueError: if ``config['type']`` is not a supported layer name.
    """
    layer_type = config['type']
    if layer_type == 'adversarial':
        return AdversarialLatentLayer(**config)
    if layer_type == 'vae':
        return VAELatentLayer(**config)
    if layer_type == 'split':
        return SplitLatentLayer(**config)
    if layer_type == 'merge':
        return MergeLatentLayer(**config)
    if layer_type == 'gmvae':
        return GMVAELatentLayer(**config)
    if layer_type == 'vqvae':
        return VQVAELatentLayer(**config)
    # No branch matched: report the offending name to the caller.
    raise ValueError(f"Unrecognized latent model name: {layer_type}")

class PlaceholderLayer(nn.Module):
    """No-op latent layer.

    Hands ``x_dict['h']`` back unchanged and contributes a loss of 0. Any
    keyword arguments passed to the constructor are accepted and ignored.
    """

    def __init__(self, **kwargs):
        super().__init__()
        # Flagged non-adversarial so discriminator training skips it.
        self.is_adversarial = False

    def forward(self, x_dict):
        """Return ``(x_dict['h'], 0)`` without modifying anything."""
        h = x_dict['h']
        no_loss = 0
        return h, no_loss

class LatentModel(nn.Module):
    """Sequential stack of latent layers that share one ``x_dict``.

    Each layer is called as ``layer(x_dict)`` and must return ``(h, loss)``.
    The returned ``h`` is written back into ``x_dict['h']`` before the next
    layer runs, and the per-layer losses are summed.

    Args:
        configs: optional iterable of config dicts; each one is passed to
            ``create_latent_layer`` and the result is appended after the
            placeholder layer.
    """

    def __init__(self, configs=None):
        super().__init__()
        # A PlaceholderLayer always sits at index 0.
        # NOTE(review): presumably kept so ModuleList indices (and therefore
        # state_dict keys such as 'layers.1.*') stay stable; confirm before
        # removing it.
        self.layers = nn.ModuleList([PlaceholderLayer()])
        if configs is not None:
            for c in configs:
                self.layers.append(create_latent_layer(**c))

    def forward(self, x_dict):
        """Run every layer in order.

        ``x_dict`` is mutated in place: after this returns, ``x_dict['h']``
        holds the output of the last layer.

        Returns:
            ``(h, total_loss)``, the final ``h`` and the sum of the losses
            returned by each layer.
        """
        total_loss = 0
        for layer in self.layers:
            x_dict['h'], loss = layer(x_dict)
            total_loss += loss
        return x_dict['h'], total_loss

    def add_layer(self, **config):
        """Append a layer built by ``create_latent_layer(**config)``."""
        self.layers.append(create_latent_layer(**config))

    def d_train(self, x_dict):
        """Call ``d_iter(x_dict)`` on every layer flagged as adversarial.

        A layer that does not define ``is_adversarial`` is treated as
        non-adversarial rather than raising ``AttributeError``.
        """
        for layer in self.layers:
            if getattr(layer, 'is_adversarial', False):
                layer.d_iter(x_dict)